load("@global_pybind11_bazel//:build_defs.bzl", "pybind_extension", "pybind_library")
load("//:bazel/pybind11_defs.bzl", "py_extension")
load("@org_tensorflow//tensorflow:tensorflow.bzl", "tf_cc_binary")
load(
    "@org_tensorflow//tensorflow/core/platform:build_config.bzl",
    "tf_proto_library",
)

load("@rules_proto//proto:defs.bzl", "proto_library")
load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_proto_library")
load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", "cc_grpc_library")
# Compiler warning flags applied via `copts` to this package's own C++
# libraries; -Werror promotes every warning to a build failure.
CC_OPTS = [
    "-Wall",
    "-Werror",
    "-Wextra",
]
# Wraps TensorFlow's configure.py as a runnable py_binary. The files listed in
# `data` are made available in the binary's runfiles.
py_binary(
    name = "tf_config",
    srcs = ["@org_tensorflow//:configure.py"],
    data = [
        "@org_tensorflow//:tools/tf_env_collect.sh",
        "@org_tensorflow//:.bazelrc",
        "@org_tensorflow//third_party/gpus:find_cuda_config.py",
    ],
    main = "@org_tensorflow//:configure.py",
)

# Third-party libraries that are bundled behind the ThirdCommon target.
GEESIBLING_THIRD_COMMON_LIBS = [
    "@fmt",
    "@result",
    "@range-v3",
    "@spdlog//:spdlog_header_only",
]

# Dependency-only target: it declares no sources or headers of its own and
# simply forwards the libraries listed in GEESIBLING_THIRD_COMMON_LIBS to
# whatever depends on it.
cc_library(
    name = "ThirdCommon",
    deps = GEESIBLING_THIRD_COMMON_LIBS,
)

# Header-only target exposing every .hpp under ccsrc/common/include, with that
# directory added to the include path of dependents.
cc_library(
    name = "CommonHeaders",
    hdrs = glob(
        ["ccsrc/common/include/**/*.hpp"],
    ),
    includes = ["ccsrc/common/include"],
    deps = [":ThirdCommon"],
)

# Header-only target exposing every .hpp under ccsrc/DistributedIR/include,
# with that directory added to the include path of dependents.
cc_library(
    name = "DistributedIRHeaders",
    hdrs = glob(
        ["ccsrc/DistributedIR/include/**/*.hpp"],
    ),
    includes = ["ccsrc/DistributedIR/include"],
    deps = [":CommonHeaders"],
)

# Header-only target for the TensorFlow adapter: collects both .hpp and .h
# files below ccsrc/adapters/include/adapters/tensorflow. The
# ccsrc/adapters/include/ prefix is stripped from the exported header paths and
# the same directory is also added to the include path.
cc_library(
    name = "TFAdapterHeaders",
    hdrs = glob([
        "ccsrc/adapters/include/adapters/tensorflow/**/*.hpp",
        "ccsrc/adapters/include/adapters/tensorflow/**/*.h",
    ]),
    strip_include_prefix = "ccsrc/adapters/include/",
    includes = ["ccsrc/adapters/include"],
    deps = [":DistributedIRHeaders"],
)

# Placement pass for the TensorFlow adapter, built from placement_pass.cc.
# `alwayslink` forces the object code into every binary that depends on this
# target even when none of its symbols are referenced directly (presumably so
# a static registration in the pass is not dropped by the linker; confirm in
# placement_pass.cc).
cc_library(
    name = "TFAdapterPlacementPass",
    srcs = ["ccsrc/adapters/tensorflow/pass/placement_pass.cc"],
    hdrs = ["ccsrc/adapters/include/adapters/tensorflow/pass/placement_pass.h"],
    includes = ["ccsrc/adapters/include"],
    deps = [
        ":TFAdapterHeaders",
        ":FDDPSPolicy",
        ":SGPPolicy",
        "@org_tensorflow//tensorflow/core/common_runtime:optimization_registry",
        "@org_tensorflow//tensorflow/core/grappler/clusters:virtual_cluster",
        "@jsoncpp_git//:jsoncpp",
    ],
    alwayslink = True,
)

# DistributedIR library: compiles the .cc files located directly in
# ccsrc/DistributedIR (the glob is not recursive, so subdirectories are not
# picked up) using the CC_OPTS warning flags.
cc_library(
    name = "DistributedIR",
    srcs = glob(["ccsrc/DistributedIR/*.cc"]),
    copts = CC_OPTS,
    deps = [":DistributedIRHeaders"],
)
# Header-only target exposing every .hpp under ccsrc/cost_graph/include, with
# that directory added to the include path of dependents.
cc_library(
    name = "CostGraphHeaders",
    hdrs = glob(
        ["ccsrc/cost_graph/include/**/*.hpp"],
    ),
    includes = ["ccsrc/cost_graph/include"],
    deps = [":DistributedIRHeaders"],
)


# Header-only target exposing every .hpp under ccsrc/policy/include. Dependents
# also receive the common, cost-graph and cluster headers plus the embedded
# pybind11 library through `deps`.
cc_library(
    name = "PolicyHeaders",
    hdrs = glob(
        ["ccsrc/policy/include/**/*.hpp"],
    ),
    includes = ["ccsrc/policy/include"],
    deps = [
        ":CommonHeaders",
        ":CostGraphHeaders",
        "@global_pybind11//:pybind11_embed",
        ":ClusterHeaders",
    ],
)

# Header-only target exposing every .hpp under ccsrc/cluster/include, with that
# directory added to the include path of dependents.
cc_library(
    name = "ClusterHeaders",
    hdrs = glob(
        ["ccsrc/cluster/include/**/*.hpp"],
    ),
    includes = ["ccsrc/cluster/include"],
    deps = [":CommonHeaders"],
)


# FD-DPS policy library: compiles the .cc files directly under
# ccsrc/policy/fd-dps using the CC_OPTS warning flags.
cc_library(
    name = "FDDPSPolicy",
    srcs = glob(["ccsrc/policy/fd-dps/*.cc"]),
    copts = CC_OPTS,
    deps = [":PolicyHeaders"],
)

# SGP policy library: compiles the .cc files directly under ccsrc/policy/sgp
# using the CC_OPTS warning flags.
cc_library(
    name = "SGPPolicy",
    srcs = glob(["ccsrc/policy/sgp/*.cc"]),
    copts = CC_OPTS,
    deps = [":PolicyHeaders"],
)

# Python-facing `_graph` target built from the sources in
# ccsrc/DistributedIR/python and linked against DistributedIR.
# NOTE(review): py_extension comes from //:bazel/pybind11_defs.bzl; presumably
# it produces a pybind11 extension module. Confirm against that macro.
py_extension(
    name = "_graph",
    srcs = glob(
        ["ccsrc/DistributedIR/python/*.cc"],
    ),
    deps = [":DistributedIR"],
)


# Command-line binary built from tools/tf_run_graph.cc through TensorFlow's
# tf_cc_binary macro and linked against the TensorFlow core targets below.
tf_cc_binary(
    name = "tf-run-graph",
    srcs = [
        "tools/tf_run_graph.cc",
    ],
    # Only the default branch exists today, so every configuration links libm.
    linkopts = select({
        "//conditions:default": ["-lm"],
    }),
    # Only the default branch exists today, so every configuration gets the
    # same dependency list.
    deps = select({
        "//conditions:default": [
            "@org_tensorflow//tensorflow/core:core_cpu",
            "@org_tensorflow//tensorflow/core:framework",
            "@org_tensorflow//tensorflow/core:framework_internal",
            "@org_tensorflow//tensorflow/core:lib",
            "@org_tensorflow//tensorflow/core:protos_all_cc",
            "@org_tensorflow//tensorflow/core:tensorflow",
        ],
    }),
)

